Predict the future? Let’s just see if we can predict what happened last year!

  • As you know, Facebook’s Prophet model is easy to use and works well when seasonality is clear
  • But it doesn’t have any built-in way to deal with confounders
  • That’s a major limitation if you want a short-term, potentially high-variance forecast
  • We also know that interactions exist in most data sets. Prophet is good at fitting curves, but it walks right past covariates that are $100 bills on the sidewalk

Let’s simulate some data

Let’s assume we have pretty predictable seasonal trends in subscribers.

# load libraries
library(tidyverse)
library(lubridate)
library(usaidplot)
library(ggridges)
library(prophet)
library(bsts)
library(gt)
library(gtExtras)

# set seed
set.seed(6292025)

# Simulate weekly dates
n_weeks <- 156
weeks <- seq.Date(from = as.Date("2022-01-02"), by = "week", length.out = n_weeks)

# Categories
sports <- c("baseball", "basketball", "football", "soccer")
sport_category <- c("school", "club", "recreation")
age_group <- c("U13", "13_18", "18plus")

# Seasonal trend function with peaks
seasonal_trend <- function(week_date) {
  week_of_year <- isoweek(week_date)
  spring_peak <- exp(-0.03 * (week_of_year - 18)^2) * 800
  fall_bump <- exp(-0.02 * (week_of_year - 42)^2) * 200
  500 + spring_peak + fall_bump
}

# Simulate data with a much higher subscribe probability for baseball U13
sim_data <- map_dfr(weeks, function(week) {
  base_n <- round(seasonal_trend(week) + rnorm(1, 0, 50))
  if (base_n <= 0) return(tibble())

  tibble(
    sport = sample(sports, base_n, replace = TRUE, prob = c(0.5, 0.2, 0.15, 0.15)),
    category = sample(sport_category, base_n, replace = TRUE),
    age_group = sample(age_group, base_n, replace = TRUE, prob = c(0.5, 0.3, 0.2)),
    tenure_weeks = rpois(base_n, lambda = 12)
  ) |>
    mutate(
      week = week,
      subscribe_prob = case_when(
        sport == "baseball" & age_group == "U13" ~ 0.95
        , sport == "baseball" ~ 0.7
        , age_group == "U13" ~ 0.6
        , TRUE ~ 0.4
      ),
      is_subscribed = rbinom(n(), 1, subscribe_prob)
    ) |>
    filter(is_subscribed == 1) |>
    select(week, sport, category, age_group, tenure_weeks)
})


# Aggregate to weekly subscriber counts by group
summary_data <- sim_data |>
  count(week, sport, age_group) |>
  pivot_wider(names_from = c(sport, age_group), values_from = n, values_fill = 0)

# Sum across all sport/age-group columns to get weekly totals
total_cols <- summary_data |> select(-week) |> colnames()

summary_data <- summary_data |>
  mutate(total_subs = rowSums(across(all_of(total_cols))))

ggplot(summary_data, aes(x = week, y = total_subs)) +
  geom_line(color = "#5F9EA0", linewidth = 1) +
  labs(title = "Simulated Subscribers with Seasonal Peaks", x = "", y = " ") +
  usaid_plot()

Let’s also add in some confounding

Our simulated data also include higher variation for baseball subscribers, particularly for the Under 13 age group.

sim_data |>
    count(week, sport, age_group) |> 
    mutate(age_group = case_when(age_group == "U13" ~ "Under 13"
                                 , age_group == "13_18" ~ "13-18"
                                 , age_group == "18plus" ~ "18+"
    )
    , sport = Hmisc::capitalize(sport)
    ) |> 
    mutate(age_group = factor(age_group, levels = c("Under 13", "13-18", "18+"))) |> 
    ggplot() +
    geom_density_ridges(aes(x = n, y = fct_rev(age_group), fill = sport), alpha = 0.6) +
    usaid_plot() +
    scale_fill_viridis_d() +
    theme(legend.position = "top") +
    labs(x = "", y = "" , title = "Distribution of Subscribers by Age Group and Sport")

Forecasting

Prophet does well on our simulated data, even though we know there is predictive value in age group and sport.

# create a train and test split where we hold out the last year for forecast testing

train_data <- summary_data |> filter(week <= max(week) - weeks(52))
test_data <- summary_data |>  filter(week > max(week) - weeks(52))

# fit our prophet model on the training data only, so the last year is a true holdout
prophet_data <- train_data |> select(ds = week, y = total_subs)
prophet_model <- prophet(prophet_data)

future <- make_future_dataframe(prophet_model, periods = 52, freq = "week")

forecast <- predict(prophet_model, future) |> 
  mutate(ds = as_date(ds))

summary_data |> 
  select(week, total_subs) |> 
  left_join(
    forecast |> 
      filter(ds > max(train_data$week)) |> 
      select("week" = ds, yhat, yhat_lower, yhat_upper)
    , by = "week"
  ) |> 
  ggplot(aes(x = week)) + 
  geom_line(aes(y = total_subs), col = "#440154FF") +
  geom_line(aes(y = yhat), col = "#22A884FF", linetype = "dashed") +
  geom_ribbon(aes(ymin = yhat_lower, ymax = yhat_upper), fill = "#22A884FF", alpha = 0.3) +
  usaid_plot() + 
  labs(x = "", y = "Total Subscribers", title = "Prophet Model Against Actuals for the Last Year in Data")

Prophet’s performance stats are really strong

Even without accounting for covariates, Prophet does well.

# with only 104 training weeks and a 52-week horizon, the initial window
# must be short enough to leave room for at least one cutoff
df_cv <- cross_validation(prophet_model, initial = 50, period = 1, horizon = 52, units = 'weeks')

df_perf <- as_tibble(performance_metrics(df_cv))

df_perf |> 
  ggplot(aes(x = horizon, y = rmse)) +
  geom_line(color = "#22A884FF", linewidth = 1) +
  geom_hline(yintercept = mean(df_perf$rmse), linetype = "dashed") +
  usaid_plot() +
  labs(x = "Forecast horizon (weeks)", y = "RMSE", title = "Prophet's error is low even for dates a year out"
      , subtitle = "Indeed, so low, I kind of worry about overfitting"
      , caption = "Dashed line shows mean RMSE")

Another approach: BSTS

  • Developed by Google in 2013 for nowcasting
  • Similar to Prophet in that it decomposes time series into different elements
  • BUT unlike Prophet, it incorporates covariates and lets users see how each variable relates to the outcome of interest
  • This last point is important for product segmentation and for understanding why a forecasted trend is expected

Let’s look under the hood at BSTS and then show some results!

BSTS uses “states” to forecast. Several states are summed to produce the estimate, but the main ones are:

  • Trend: captures the overall direction

  • Seasonal: captures repeating patterns

  • Regression: captures covariates, important for shocks/outliers

From Scott and Varian 2013, Google
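
In equation form, the decomposition looks like this (a sketch of the standard setup from Scott and Varian 2013; since the code below uses AddLocalLevel, the trend is shown as a local level rather than a local linear trend):

$$
\begin{aligned}
y_t &= \mu_t + \tau_t + \beta^\top x_t + \varepsilon_t && \text{(observation: the states sum to the estimate)} \\
\mu_{t+1} &= \mu_t + u_t && \text{(trend: a local level, i.e., a random walk)} \\
\tau_{t+1} &= -\textstyle\sum_{s=0}^{S-2} \tau_{t-s} + w_t && \text{(seasonal: } S = 52 \text{ effects summing to } \approx 0\text{)}
\end{aligned}
$$

where $\varepsilon_t$, $u_t$, and $w_t$ are independent Gaussian errors, and the regression coefficients $\beta$ get a spike-and-slab prior (more on that below).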


# let's show an example of spike and slab
# (label positions were hand-placed; values rounded for readability)

tibble(x = abs(rnorm(1000, mean = 2, sd = 0.5))) |> 
  ggplot(aes(x)) + 
  geom_density(lwd = 1, fill = "#FDE725FF", alpha = 0.4) +
  geom_vline(xintercept = 0, lwd = 1, linetype = "dashed") +
  usaid_plot() + 
  xlim(-1, 4) +
  geom_text(data = data.frame(x = c(-0.15, 1.84),
                              y = c(0.63, 0.55),
                              label = c("Spike", "Slab")),
            mapping = aes(x = x, y = y, label = label),
            family = "Gill Sans", inherit.aes = FALSE, size = 5) +
  labs(x = "", y = "")
  • Both Prophet and BSTS decompose time series into parts like trend and seasonality.
  • Prophet assumes fixed seasonal repetition over a trend, with changepoints made explicit, while BSTS states adapt over time based on the data
  • BSTS uses spike-and-slab priors in the regression component, which helps identify confounders and avoid overfitting (a sketch of tuning that prior follows below)
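
As a minimal, hedged sketch of that last point (not code from the original analysis): bsts builds a spike-and-slab prior from the formula internally, but you can construct one yourself with BoomSpikeSlab::SpikeSlabPrior (BoomSpikeSlab loads with bsts) and pass it via the prior argument. The expected.model.size value of 2 is an illustrative assumption, and the state specification mirrors the fit in the next section.

# hedged sketch: a sparser spike-and-slab prior, so fewer covariates
# survive selection; expected.model.size = 2 is an illustrative assumption
ss_sketch <- AddLocalLevel(list(), train_data$total_subs)
ss_sketch <- AddSeasonal(ss_sketch, train_data$total_subs, nseasons = 52)

X <- model.matrix(total_subs ~ . - week, data = train_data)

sparse_prior <- BoomSpikeSlab::SpikeSlabPrior(
  x = X,
  y = train_data$total_subs,
  expected.model.size = 2  # prior belief: only ~2 covariates truly matter
)

bsts_sparse <- bsts(total_subs ~ . - week,
                    state.specification = ss_sketch,
                    data = train_data,
                    niter = 2000,
                    prior = sparse_prior,
                    ping = 0)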

Forecasting with BSTS

We get similar results, but with less uncertainty around the forecast compared to our Prophet model.

# Split training and testing (last year weeks as holdout)
train_data <- summary_data |> filter(week <= max(week) - weeks(52))
test_data <- summary_data |> filter(week > max(week) - weeks(52))

# Set up BSTS
ss <- AddLocalLevel(list(), train_data$total_subs)
ss <- AddSeasonal(ss, train_data$total_subs, nseasons = 52)

bsts_model <- bsts(total_subs ~ . - week, state.specification = ss, niter = 2000, data = train_data, ping = -1)

# predict last 52 weeks using true covariates
test_cov <- test_data |> select(-total_subs)
bsts_forecast <- predict(bsts_model, horizon = 52, newdata = test_cov)

bsts_df <- tibble(mean = bsts_forecast$mean
                  , lower = bsts_forecast$interval[1, ]
                  , upper = bsts_forecast$interval[2, ]
                  , week = test_data$week
)

ggplot(summary_data, aes(x = week)) + 
  geom_line(aes(y = total_subs), col = "#440154FF") +
  geom_line(data = bsts_df, aes(y = mean), col = "#22A884FF", linetype = "dashed") +
  geom_ribbon(data = bsts_df, aes(ymin = lower, ymax = upper), fill = "#22A884FF", alpha = 0.3) +
  usaid_plot() + 
  labs(x = "", y = "Total Subscribers", title = "BSTS Model Against Actuals for the Last Year in Data")

BSTS Performance

horizon_weeks <- 52   
initial_window <- 104 
period <- 1           
niter <- 1500 

# create results holder
rmse_results <- list()

# Loop through rolling forecast origins
# (with 156 weeks, a 104-week initial window, and a 52-week horizon,
# there is exactly one origin here; shrink initial_window for more folds)
total_obs <- nrow(summary_data)

for (start_idx in seq(initial_window, total_obs - horizon_weeks, by = period)) {
  
  # split train/test sets
  train_data <- summary_data |> slice(1:start_idx)
  test_data <- summary_data |> slice((start_idx + 1):(start_idx + horizon_weeks))
  
  # Specify BSTS state components same as above
  ss <- AddLocalLevel(list(), train_data$total_subs)
  ss <- AddSeasonal(ss, train_data$total_subs, nseasons = 52)
  
  # Fit BSTS model
  bsts_fit <- bsts(total_subs ~ . - week, state.specification = ss, niter = niter, data = train_data, ping = -1)
  
  # Forecast with covariates
  new_covariates <- test_data |> select(-total_subs)
  forecast_out <- predict(bsts_fit, horizon = horizon_weeks, newdata = new_covariates)
  
  # Store squared errors per horizon (RMSE is taken across folds below)
  sq_error_vec <- (forecast_out$mean - test_data$total_subs)^2
  
  rmse_results[[length(rmse_results) + 1]] <- tibble(
    horizon = 1:horizon_weeks,
    sq_error = sq_error_vec
  )
  
}

# Aggregate to RMSE by horizon across folds
rmse_summary <- bind_rows(rmse_results) |>
  summarise(
    avg_rmse = sqrt(mean(sq_error, na.rm = TRUE)),
    .by = horizon
  )

# Graph it
ggplot(rmse_summary, aes(x = horizon, y = avg_rmse)) +
  geom_line(color = "#22A884FF", linewidth = 1) +
  geom_hline(yintercept = mean(rmse_summary$avg_rmse), linetype = "dashed") + 
  usaid_plot() +
  labs(x = "Forecast horizon (weeks)", y = "RMSE", title = "BSTS's error is low even for dates a year out"
      , subtitle = "While overfitting may still be an issue,\nwe know we've minimized it through BSTS priors"
      , caption = "Dashed line shows mean RMSE")

Checking out the BSTS components

# this bit (and the figure in the next chunk) is adapted from
# https://multithreaded.stitchfix.com/blog/2016/04/21/forget-arima/

# define a burn-in, i.e., the number of initial MCMC draws to discard (10% here)
burn <- SuggestBurn(0.1, bsts_model)


components_withreg <- cbind.data.frame(
  colMeans(bsts_model$state.contributions[-(1:burn), "trend", ]),
  colMeans(bsts_model$state.contributions[-(1:burn), "seasonal.52.1", ]),
  colMeans(bsts_model$state.contributions[-(1:burn), "regression", ]),
  train_data$week)  # week is already a Date; no ts conversion needed


names(components_withreg) <- c("Trend", "Seasonality", "Regression", "Date")
components_withreg <- reshape2::melt(components_withreg, id.vars="Date")
names(components_withreg) <- c("Date", "Component", "Value")

ggplot(data=components_withreg, aes(x=Date, y=Value)) + geom_line(aes(color = Component)) + 
  usaid_plot() + theme(legend.title = element_blank()) + labs(x = "", y = "") +
  facet_grid(Component ~ ., scales="free") 

inclusionprobs <- reshape2::melt(colMeans(bsts_model$coefficients[-(1:burn),] != 0))

inclusionprobs$Variable <- as.character(row.names(inclusionprobs))

ggplot(data=inclusionprobs, aes(x=reorder(Variable, value), y=value)) + 
    geom_bar(stat="identity", position="identity") + 
  usaid_plot() +
    theme(axis.text.x=element_text(angle = -90, hjust = 0)) + 
    labs(x = "", y = "", title = "Inclusion probabilities"
         , subtitle = "Since we do 2,000 MCMC runs,\nwe can see what covariates are most likely to be included\nand which get regularized away")

Comparing Prophet and BSTS

# compare forecasts for the last year
forecast_last52 <- forecast |> filter(ds > max(train_data$week))

results <- test_data |>
  left_join(forecast_last52 |> select(ds, prophet_yhat = yhat), by = c("week" = "ds")) |>
  mutate(bsts_mean = bsts_forecast$mean)

# Accuracy Metrics
rmse <- function(actual, predicted) sqrt(mean((actual - predicted)^2))
mae <- function(actual, predicted) mean(abs(actual - predicted))
mape <- function(actual, predicted) mean(abs((actual - predicted)/actual)) * 100

metrics <- tibble(
  Model = c("Prophet", "BSTS"),
  RMSE = c(rmse(results$total_subs, results$prophet_yhat), rmse(results$total_subs, results$bsts_mean)),
  MAE = c(mae(results$total_subs, results$prophet_yhat), mae(results$total_subs, results$bsts_mean)),
  MAPE = c(mape(results$total_subs, results$prophet_yhat), mape(results$total_subs, results$bsts_mean))
)

# Display gt table
gt(metrics) |>
  fmt_number(columns = c(RMSE, MAE, MAPE), decimals = 2) |>
  tab_header(title = "Forecast Stats")
  • Overall, BSTS performs better, especially when we know there are confounders
  • BSTS does run slower than Prophet because of the MCMC process
  • We also haven’t added any informative priors, which could help improve the BSTS forecasts (a hedged sketch of one option follows the table below)
  • The key takeaway: BSTS allows us to unpack results in a way that Prophet does not, and it performs really well!
Forecast Stats

Model     RMSE    MAE   MAPE (%)
Prophet  35.56  27.79       6.54
BSTS     15.75  11.74       2.50
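
To make the informative-priors point concrete, here is a hedged sketch (not part of the original analysis), in the same spirit as the earlier spike-and-slab sketch: SpikeSlabPrior accepts prior.inclusion.probabilities, so we can encode a belief that the baseball U13 segment belongs in the model. All probability values below are illustrative assumptions, and ss and train_data are as defined in the BSTS section.

# hedged sketch: informative inclusion priors for the regression component
# (all probability values are illustrative assumptions)
X <- model.matrix(total_subs ~ . - week, data = train_data)

incl_probs <- rep(0.5, ncol(X))                    # agnostic default for each coefficient
incl_probs[colnames(X) == "baseball_U13"] <- 0.95  # strong belief this segment matters

informative_prior <- BoomSpikeSlab::SpikeSlabPrior(
  x = X,
  y = train_data$total_subs,
  prior.inclusion.probabilities = incl_probs
)

bsts_informed <- bsts(total_subs ~ . - week,
                      state.specification = ss,
                      data = train_data,
                      niter = 2000,
                      prior = informative_prior,
                      ping = 0)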